
<!doctype html>
<html lang="en" class="no-js">
  <head>
    
      <meta charset="utf-8">
      <meta name="viewport" content="width=device-width,initial-scale=1">
      
      
      
        <link rel="canonical" href="https://pytorch-widedeep.readthedocs.io/pytorch-widedeep/callbacks.html">
      
      
        <link rel="prev" href="dataloaders.html">
      
      
        <link rel="next" href="trainer.html">
      
      
      <link rel="icon" href="../assets/images/favicon.ico">
      <meta name="generator" content="mkdocs-1.6.1, mkdocs-material-9.5.43">
    
    
      
        <title>Callbacks - pytorch_widedeep</title>
      
    
    
      <link rel="stylesheet" href="../assets/stylesheets/main.0253249f.min.css">
      
        
        <link rel="stylesheet" href="../assets/stylesheets/palette.06af60db.min.css">
      
      


    
    
      
    
    
      
        
        
        <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
        <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300i,400,400i,700,700i%7CRoboto+Mono:400,400i,700,700i&display=fallback">
        <style>:root{--md-text-font:"Roboto";--md-code-font:"Roboto Mono"}</style>
      
    
    
      <link rel="stylesheet" href="../assets/_mkdocstrings.css">
    
      <link rel="stylesheet" href="../stylesheets/extra.css">
    
    <!--
      Inline bootstrap helpers, assigned as implicit globals:
      __md_scope: URL obtained by resolving ".." against the current location.
      __md_hash(str): folds the characters of str into an integer, starting at 0
        and computing (h shifted left by 5) minus h plus the char code per step.
      __md_get(name, storage, scope): returns JSON.parse of the item stored under
        scope.pathname + "." + name; storage defaults to localStorage and scope
        defaults to __md_scope.
      __md_set(name, value, storage, scope): stores JSON.stringify(value) under the
        same key scheme; any exception raised while storing is caught and ignored.
    -->
    <script>__md_scope=new URL("..",location),__md_hash=e=>[...e].reduce(((e,_)=>(e<<5)-e+_.charCodeAt(0)),0),__md_get=(e,_=localStorage,t=__md_scope)=>JSON.parse(_.getItem(t.pathname+"."+e)),__md_set=(e,_,t=localStorage,a=__md_scope)=>{try{t.setItem(a.pathname+"."+e,JSON.stringify(_))}catch(e){}}</script>
    
      

    
    
    
  </head>
  
  
    
    
      
    
    
    
    
    <body dir="ltr" data-md-color-scheme="default" data-md-color-primary="red" data-md-color-accent="deep-orange">
  
    
    <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off">
    <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off">
    <label class="md-overlay" for="__drawer"></label>
    <div data-md-component="skip">
      
        
        <a href="#callbacks" class="md-skip">
          Skip to content
        </a>
      
    </div>
    <div data-md-component="announce">
      
    </div>
    
    
      

  

<header class="md-header md-header--shadow md-header--lifted" data-md-component="header">
  <nav class="md-header__inner md-grid" aria-label="Header">
    <a href="../index.html" title="pytorch_widedeep" class="md-header__button md-logo" aria-label="pytorch_widedeep" data-md-component="logo">
      
  <img src="../assets/images/widedeep_logo.png" alt="logo">

    </a>
    <label class="md-header__button md-icon" for="__drawer">
      
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3zm0 5h18v2H3zm0 5h18v2H3z"/></svg>
    </label>
    <div class="md-header__title" data-md-component="header-title">
      <div class="md-header__ellipsis">
        <div class="md-header__topic">
          <span class="md-ellipsis">
            pytorch_widedeep
          </span>
        </div>
        <div class="md-header__topic" data-md-component="header-topic">
          <span class="md-ellipsis">
            
              Callbacks
            
          </span>
        </div>
      </div>
    </div>
    
      
        <form class="md-header__option" data-md-component="palette">
  
    
    
    
    <input class="md-option" data-md-color-media="" data-md-color-scheme="default" data-md-color-primary="red" data-md-color-accent="deep-orange"  aria-label="Switch to dark mode"  type="radio" name="__palette" id="__palette_0">
    
      <label class="md-header__button md-icon" title="Switch to dark mode" for="__palette_1" hidden>
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a4 4 0 0 0-4 4 4 4 0 0 0 4 4 4 4 0 0 0 4-4 4 4 0 0 0-4-4m0 10a6 6 0 0 1-6-6 6 6 0 0 1 6-6 6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
      </label>
    
  
    
    
    
    <input class="md-option" data-md-color-media="" data-md-color-scheme="slate" data-md-color-primary="red" data-md-color-accent="deep-orange"  aria-label="Switch to light mode"  type="radio" name="__palette" id="__palette_1">
    
      <label class="md-header__button md-icon" title="Switch to light mode" for="__palette_0" hidden>
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 18c-.89 0-1.74-.2-2.5-.55C11.56 16.5 13 14.42 13 12s-1.44-4.5-3.5-5.45C10.26 6.2 11.11 6 12 6a6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
      </label>
    
  
</form>
      
    
    
      <!--
        Restores the stored colour palette onto the body element.
        1. Reads the "__palette" entry through __md_get and continues only if it
           has a "color" property.
        2. If color.media is exactly "(prefers-color-scheme)", evaluates
           matchMedia("(prefers-color-scheme: light)") and selects the element whose
           data-md-color-media attribute is the matching light or dark query, then
           copies that element's data-md-color-media / scheme / primary / accent
           attributes into palette.color.
        3. Sets every key/value pair of palette.color on document.body as the
           attribute "data-md-color-" + key.
        NOTE(review): the querySelector result is dereferenced without a null
        check; presumably a matching input always exists whenever that media value
        has been stored. Confirm against the palette configuration.
      -->
      <script>var palette=__md_get("__palette");if(palette&&palette.color){if("(prefers-color-scheme)"===palette.color.media){var media=matchMedia("(prefers-color-scheme: light)"),input=document.querySelector(media.matches?"[data-md-color-media='(prefers-color-scheme: light)']":"[data-md-color-media='(prefers-color-scheme: dark)']");palette.color.media=input.getAttribute("data-md-color-media"),palette.color.scheme=input.getAttribute("data-md-color-scheme"),palette.color.primary=input.getAttribute("data-md-color-primary"),palette.color.accent=input.getAttribute("data-md-color-accent")}for(var[key,value]of Object.entries(palette.color))document.body.setAttribute("data-md-color-"+key,value)}</script>
    
    
    
      <label class="md-header__button md-icon" for="__search">
        
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
      </label>
      <div class="md-search" data-md-component="search" role="dialog">
  <label class="md-search__overlay" for="__search"></label>
  <div class="md-search__inner" role="search">
    <form class="md-search__form" name="search">
      <input type="text" class="md-search__input" name="query" aria-label="Search" placeholder="Search" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" required>
      <label class="md-search__icon md-icon" for="__search">
        
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
        
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
      </label>
      <nav class="md-search__options" aria-label="Search">
        
        <button type="reset" class="md-search__icon md-icon" title="Clear" aria-label="Clear" tabindex="-1">
          
          <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z"/></svg>
        </button>
      </nav>
      
    </form>
    <div class="md-search__output">
      <div class="md-search__scrollwrap" tabindex="0" data-md-scrollfix>
        <div class="md-search-result" data-md-component="search-result">
          <div class="md-search-result__meta">
            Initializing search
          </div>
          <ol class="md-search-result__list" role="presentation"></ol>
        </div>
      </div>
    </div>
  </div>
</div>
    
    
      <div class="md-header__source">
        <a href="https://github.com/jrzaurin/pytorch-widedeep" title="Go to repository" class="md-source" data-md-component="source">
  <div class="md-source__icon md-icon">
    
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 6.6.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M439.55 236.05 244 40.45a28.87 28.87 0 0 0-40.81 0l-40.66 40.63 51.52 51.52c27.06-9.14 52.68 16.77 43.39 43.68l49.66 49.66c34.23-11.8 61.18 31 35.47 56.69-26.49 26.49-70.21-2.87-56-37.34L240.22 199v121.85c25.3 12.54 22.26 41.85 9.08 55a34.34 34.34 0 0 1-48.55 0c-17.57-17.6-11.07-46.91 11.25-56v-123c-20.8-8.51-24.6-30.74-18.64-45L142.57 101 8.45 235.14a28.86 28.86 0 0 0 0 40.81l195.61 195.6a28.86 28.86 0 0 0 40.8 0l194.69-194.69a28.86 28.86 0 0 0 0-40.81"/></svg>
  </div>
  <div class="md-source__repository">
    pytorch_widedeep
  </div>
</a>
      </div>
    
  </nav>
  
    
      
<nav class="md-tabs" aria-label="Tabs" data-md-component="tabs">
  <div class="md-grid">
    <ul class="md-tabs__list">
      
        
  
  
  
    <li class="md-tabs__item">
      <a href="../index.html" class="md-tabs__link">
        
  
    
  
  Home

      </a>
    </li>
  

      
        
  
  
  
    <li class="md-tabs__item">
      <a href="../installation.html" class="md-tabs__link">
        
  
    
  
  Installation

      </a>
    </li>
  

      
        
  
  
  
    <li class="md-tabs__item">
      <a href="../quick_start.html" class="md-tabs__link">
        
  
    
  
  Quick Start

      </a>
    </li>
  

      
        
  
  
    
  
  
    
    
      
  
  
    
  
  
    
    
      <li class="md-tabs__item md-tabs__item--active">
        <a href="utils/index.html" class="md-tabs__link">
          
  
    
  
  Pytorch-widedeep

        </a>
      </li>
    
  

    
  

      
        
  
  
  
    
    
      <li class="md-tabs__item">
        <a href="../examples/01_preprocessors_and_utils.html" class="md-tabs__link">
          
  
    
  
  Examples

        </a>
      </li>
    
  

      
        
  
  
  
    <li class="md-tabs__item">
      <a href="../contributing.html" class="md-tabs__link">
        
  
    
  
  Contributing

      </a>
    </li>
  

      
    </ul>
  </div>
</nav>
    
  
</header>
    
    <div class="md-container" data-md-component="container">
      
      
        
      
      <main class="md-main" data-md-component="main">
        <div class="md-main__inner md-grid">
          
            
              
              <div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" >
                <div class="md-sidebar__scrollwrap">
                  <div class="md-sidebar__inner">
                    


  


  

<nav class="md-nav md-nav--primary md-nav--lifted md-nav--integrated" aria-label="Navigation" data-md-level="0">
  <label class="md-nav__title" for="__drawer">
    <a href="../index.html" title="pytorch_widedeep" class="md-nav__button md-logo" aria-label="pytorch_widedeep" data-md-component="logo">
      
  <img src="../assets/images/widedeep_logo.png" alt="logo">

    </a>
    pytorch_widedeep
  </label>
  
    <div class="md-nav__source">
      <a href="https://github.com/jrzaurin/pytorch-widedeep" title="Go to repository" class="md-source" data-md-component="source">
  <div class="md-source__icon md-icon">
    
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 6.6.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M439.55 236.05 244 40.45a28.87 28.87 0 0 0-40.81 0l-40.66 40.63 51.52 51.52c27.06-9.14 52.68 16.77 43.39 43.68l49.66 49.66c34.23-11.8 61.18 31 35.47 56.69-26.49 26.49-70.21-2.87-56-37.34L240.22 199v121.85c25.3 12.54 22.26 41.85 9.08 55a34.34 34.34 0 0 1-48.55 0c-17.57-17.6-11.07-46.91 11.25-56v-123c-20.8-8.51-24.6-30.74-18.64-45L142.57 101 8.45 235.14a28.86 28.86 0 0 0 0 40.81l195.61 195.6a28.86 28.86 0 0 0 40.8 0l194.69-194.69a28.86 28.86 0 0 0 0-40.81"/></svg>
  </div>
  <div class="md-source__repository">
    pytorch_widedeep
  </div>
</a>
    </div>
  
  <ul class="md-nav__list" data-md-scrollfix>
    
      
      
  
  
  
  
    <li class="md-nav__item">
      <a href="../index.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Home
  </span>
  

      </a>
    </li>
  

    
      
      
  
  
  
  
    <li class="md-nav__item">
      <a href="../installation.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Installation
  </span>
  

      </a>
    </li>
  

    
      
      
  
  
  
  
    <li class="md-nav__item">
      <a href="../quick_start.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Quick Start
  </span>
  

      </a>
    </li>
  

    
      
      
  
  
    
  
  
  
    
    
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
    
    
      
        
        
      
      
    
    
    <li class="md-nav__item md-nav__item--active md-nav__item--section md-nav__item--nested">
      
        
        
        <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_4" checked>
        
          
          <label class="md-nav__link" for="__nav_4" id="__nav_4_label" tabindex="">
            
  
  <span class="md-ellipsis">
    Pytorch-widedeep
  </span>
  

            <span class="md-nav__icon md-icon"></span>
          </label>
        
        <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_4_label" aria-expanded="true">
          <label class="md-nav__title" for="__nav_4">
            <span class="md-nav__icon md-icon"></span>
            Pytorch-widedeep
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
                
  
  
  
  
    
    
      
        
          
        
      
        
      
        
      
        
      
        
      
    
    
      
      
    
    
    <li class="md-nav__item md-nav__item--nested">
      
        
        
          
        
        <input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_4_1" >
        
          
          
          <div class="md-nav__link md-nav__container">
            <a href="utils/index.html" class="md-nav__link ">
              
  
  <span class="md-ellipsis">
    Utils
  </span>
  

            </a>
            
              
              <label class="md-nav__link " for="__nav_4_1" id="__nav_4_1_label" tabindex="0">
                <span class="md-nav__icon md-icon"></span>
              </label>
            
          </div>
        
        <nav class="md-nav" data-md-level="2" aria-labelledby="__nav_4_1_label" aria-expanded="false">
          <label class="md-nav__title" for="__nav_4_1">
            <span class="md-nav__icon md-icon"></span>
            Utils
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="utils/deeptabular_utils.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Deeptabular utils
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="utils/fastai_transforms.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Fastai transforms
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="utils/image_utils.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Image utils
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="utils/text_utils.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Text utils
  </span>
  

      </a>
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="preprocessing.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Preprocessing
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="load_from_folder.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Load From Folder
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="model_components.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Model Components
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="the_rec_module.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    The Rec Module
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="bayesian_models.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Bayesian models
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="losses.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Losses
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="metrics.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Metrics
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="dataloaders.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Dataloaders
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
    
  
  
  
    <li class="md-nav__item md-nav__item--active">
      
      <input class="md-nav__toggle md-toggle" type="checkbox" id="__toc">
      
      
        
      
      
        <label class="md-nav__link md-nav__link--active" for="__toc">
          
  
  <span class="md-ellipsis">
    Callbacks
  </span>
  

          <span class="md-nav__icon md-icon"></span>
        </label>
      
      <a href="callbacks.html" class="md-nav__link md-nav__link--active">
        
  
  <span class="md-ellipsis">
    Callbacks
  </span>
  

      </a>
      
        

<nav class="md-nav md-nav--secondary" aria-label="Table of contents">
  
  
  
    
  
  
    <label class="md-nav__title" for="__toc">
      <span class="md-nav__icon md-icon"></span>
      Table of contents
    </label>
    <ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
      
        <li class="md-nav__item">
  <a href="#pytorch_widedeep.callbacks.LRHistory" class="md-nav__link">
    <span class="md-ellipsis">
      LRHistory
    </span>
  </a>
  
</li>
      
        <li class="md-nav__item">
  <a href="#pytorch_widedeep.callbacks.ModelCheckpoint" class="md-nav__link">
    <span class="md-ellipsis">
      ModelCheckpoint
    </span>
  </a>
  
</li>
      
        <li class="md-nav__item">
  <a href="#pytorch_widedeep.callbacks.EarlyStopping" class="md-nav__link">
    <span class="md-ellipsis">
      EarlyStopping
    </span>
  </a>
  
</li>
      
    </ul>
  
</nav>
      
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="trainer.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Trainer
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="bayesian_trainer.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Bayesian Trainer
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="self_supervised_pretraining.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Self Supervised Pretraining
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="tab2vec.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Tab2Vec
  </span>
  

      </a>
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

    
      
      
  
  
  
  
    
    
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
    
    
      
      
    
    
    <li class="md-nav__item md-nav__item--nested">
      
        
        
          
        
        <input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_5" >
        
          
          <label class="md-nav__link" for="__nav_5" id="__nav_5_label" tabindex="0">
            
  
  <span class="md-ellipsis">
    Examples
  </span>
  

            <span class="md-nav__icon md-icon"></span>
          </label>
        
        <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_5_label" aria-expanded="false">
          <label class="md-nav__title" for="__nav_5">
            <span class="md-nav__icon md-icon"></span>
            Examples
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/01_preprocessors_and_utils.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    01_preprocessors_and_utils
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/02_model_components.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    02_model_components
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/03_binary_classification_with_defaults.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    03_binary_classification_with_defaults
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/04_regression_with_images_and_text.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    04_regression_with_images_and_text
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/05_save_and_load_model_and_artifacts.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    05_save_and_load_model_and_artifacts
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/06_finetune_and_warmup.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    06_finetune_and_warmup
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/07_custom_components.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    07_custom_components
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/08_custom_dataLoader_imbalanced_dataset.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    08_custom_dataLoader_imbalanced_dataset
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/09_extracting_embeddings.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    09_extracting_embeddings
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/10_3rd_party_integration-RayTune_WnB.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    10_3rd_party_integration-RayTune_WnB
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/11_auc_multiclass.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    11_auc_multiclass
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/12_ZILNLoss_origkeras_vs_pytorch_widedeep.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    12_ZILNLoss_origkeras_vs_pytorch_widedeep
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/13_model_uncertainty_prediction.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    13_model_uncertainty_prediction
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/14_bayesian_models.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    14_bayesian_models
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/15_Self_Supervised_Pretraning_pt1.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    15_Self-Supervised Pre-Training pt 1
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/15_Self_Supervised_Pretraning_pt2.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    15_Self-Supervised Pre-Training pt 2
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/16_Usign_a_custom_hugging_face_model.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    16_Usign-a-custom-hugging-face-model
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/17_feature_importance_via_attention_weights.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    17_feature_importance_via_attention_weights
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/18_wide_and_deep_for_recsys_pt1.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    18_wide_and_deep_for_recsys_pt1
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/18_wide_and_deep_for_recsys_pt2.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    18_wide_and_deep_for_recsys_pt2
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/19_load_from_folder_functionality.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    19_load_from_folder_functionality
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/20_Using_huggingface_within_widedeep.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    20-Using-huggingface-within-widedeep
  </span>
  

      </a>
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

    
      
      
  
  
  
  
    <li class="md-nav__item">
      <a href="../contributing.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Contributing
  </span>
  

      </a>
    </li>
  

    
  </ul>
</nav>
                  </div>
                </div>
              </div>
            
            
          
          
            <div class="md-content" data-md-component="content">
              <article class="md-content__inner md-typeset">
                
                  

  
  


<h1 id="callbacks">Callbacks<a class="headerlink" href="#callbacks" title="Permanent link">&para;</a></h1>
<p>Here are the 4 callbacks available to the user in <code>pytorch-widedeep</code>:
<code>LRHistory</code>, <code>ModelCheckpoint</code>, <code>EarlyStopping</code> and <code>RayTuneReporter</code>.</p>
<p><img alt="ℹ️" class="emojione" src="https://cdnjs.cloudflare.com/ajax/libs/emojione/2.2.7/assets/png/2139.png" title=":information_source:" />  <strong>NOTE</strong>: other callbacks, like <code>History</code>, always run
 by default. In particular, the <code>History</code> callback saves the metrics in the
 <code>history</code> attribute of the <code>Trainer</code>.</p>


<div class="doc doc-object doc-class">



<h2 id="pytorch_widedeep.callbacks.LRHistory" class="doc doc-heading">
            <span class="doc doc-object-name doc-class-name">LRHistory</span>


<a href="#pytorch_widedeep.callbacks.LRHistory" class="headerlink" title="Permanent link">&para;</a></h2>


    <div class="doc doc-contents first">
            <p class="doc doc-class-bases">
              Bases: <code><span title="pytorch_widedeep.callbacks.Callback">Callback</span></code></p>


        <p>Saves the learning rates during training in the <code>lr_history</code> attribute
of the <code>Trainer</code>.</p>
<p>Callbacks are passed as input parameters to the <code>Trainer</code> class. See
<code>pytorch_widedeep.training.Trainer</code>.</p>


<p><span class="doc-section-title">Parameters:</span></p>
    <table>
      <thead>
        <tr>
          <th>Name</th>
          <th>Type</th>
          <th>Description</th>
          <th>Default</th>
        </tr>
      </thead>
      <tbody>
          <tr class="doc-section-item">
            <td>
                <code>n_epochs</code>
            </td>
            <td>
                  <code>int</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>number of training epochs</p>
              </div>
            </td>
            <td>
                <em>required</em>
            </td>
          </tr>
      </tbody>
    </table>


<p><span class="doc-section-title">Examples:</span></p>
    <div class="highlight"><pre><span></span><code><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.callbacks</span> <span class="kn">import</span> <span class="n">LRHistory</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.models</span> <span class="kn">import</span> <span class="n">TabMlp</span><span class="p">,</span> <span class="n">Wide</span><span class="p">,</span> <span class="n">WideDeep</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.training</span> <span class="kn">import</span> <span class="n">Trainer</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">embed_input</span> <span class="o">=</span> <span class="p">[(</span><span class="n">u</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">)</span> <span class="k">for</span> <span class="n">u</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">([</span><span class="s2">&quot;a&quot;</span><span class="p">,</span> <span class="s2">&quot;b&quot;</span><span class="p">,</span> <span class="s2">&quot;c&quot;</span><span class="p">][:</span><span class="mi">4</span><span class="p">],</span> <span class="p">[</span><span class="mi">4</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">,</span> <span class="p">[</span><span class="mi">8</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)]</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">column_idx</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">v</span><span class="p">,</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">([</span><span class="s2">&quot;a&quot;</span><span class="p">,</span> <span class="s2">&quot;b&quot;</span><span class="p">,</span> <span class="s2">&quot;c&quot;</span><span class="p">])}</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">wide</span> <span class="o">=</span> <span class="n">Wide</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">deep</span> <span class="o">=</span> <span class="n">TabMlp</span><span class="p">(</span><span class="n">mlp_hidden_dims</span><span class="o">=</span><span class="p">[</span><span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="n">column_idx</span><span class="o">=</span><span class="n">column_idx</span><span class="p">,</span> <span class="n">cat_embed_input</span><span class="o">=</span><span class="n">embed_input</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">WideDeep</span><span class="p">(</span><span class="n">wide</span><span class="p">,</span> <span class="n">deep</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">objective</span><span class="o">=</span><span class="s2">&quot;regression&quot;</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">LRHistory</span><span class="p">(</span><span class="n">n_epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">)])</span>
</code></pre></div>






              <details class="quote">
                <summary>Source code in <code>pytorch_widedeep/callbacks.py</code></summary>
                <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">251</span>
<span class="normal">252</span>
<span class="normal">253</span>
<span class="normal">254</span>
<span class="normal">255</span>
<span class="normal">256</span>
<span class="normal">257</span>
<span class="normal">258</span>
<span class="normal">259</span>
<span class="normal">260</span>
<span class="normal">261</span>
<span class="normal">262</span>
<span class="normal">263</span>
<span class="normal">264</span>
<span class="normal">265</span>
<span class="normal">266</span>
<span class="normal">267</span>
<span class="normal">268</span>
<span class="normal">269</span>
<span class="normal">270</span>
<span class="normal">271</span>
<span class="normal">272</span>
<span class="normal">273</span>
<span class="normal">274</span>
<span class="normal">275</span>
<span class="normal">276</span>
<span class="normal">277</span>
<span class="normal">278</span>
<span class="normal">279</span>
<span class="normal">280</span>
<span class="normal">281</span>
<span class="normal">282</span>
<span class="normal">283</span>
<span class="normal">284</span>
<span class="normal">285</span>
<span class="normal">286</span>
<span class="normal">287</span>
<span class="normal">288</span>
<span class="normal">289</span>
<span class="normal">290</span>
<span class="normal">291</span>
<span class="normal">292</span>
<span class="normal">293</span>
<span class="normal">294</span>
<span class="normal">295</span>
<span class="normal">296</span>
<span class="normal">297</span>
<span class="normal">298</span>
<span class="normal">299</span>
<span class="normal">300</span>
<span class="normal">301</span>
<span class="normal">302</span>
<span class="normal">303</span>
<span class="normal">304</span>
<span class="normal">305</span>
<span class="normal">306</span>
<span class="normal">307</span>
<span class="normal">308</span>
<span class="normal">309</span>
<span class="normal">310</span>
<span class="normal">311</span>
<span class="normal">312</span>
<span class="normal">313</span>
<span class="normal">314</span>
<span class="normal">315</span>
<span class="normal">316</span>
<span class="normal">317</span>
<span class="normal">318</span>
<span class="normal">319</span>
<span class="normal">320</span>
<span class="normal">321</span>
<span class="normal">322</span>
<span class="normal">323</span>
<span class="normal">324</span>
<span class="normal">325</span>
<span class="normal">326</span>
<span class="normal">327</span>
<span class="normal">328</span>
<span class="normal">329</span>
<span class="normal">330</span>
<span class="normal">331</span>
<span class="normal">332</span>
<span class="normal">333</span>
<span class="normal">334</span>
<span class="normal">335</span>
<span class="normal">336</span>
<span class="normal">337</span>
<span class="normal">338</span>
<span class="normal">339</span>
<span class="normal">340</span>
<span class="normal">341</span>
<span class="normal">342</span>
<span class="normal">343</span>
<span class="normal">344</span>
<span class="normal">345</span>
<span class="normal">346</span>
<span class="normal">347</span>
<span class="normal">348</span>
<span class="normal">349</span>
<span class="normal">350</span>
<span class="normal">351</span>
<span class="normal">352</span>
<span class="normal">353</span>
<span class="normal">354</span>
<span class="normal">355</span>
<span class="normal">356</span>
<span class="normal">357</span>
<span class="normal">358</span>
<span class="normal">359</span>
<span class="normal">360</span>
<span class="normal">361</span>
<span class="normal">362</span>
<span class="normal">363</span>
<span class="normal">364</span>
<span class="normal">365</span>
<span class="normal">366</span>
<span class="normal">367</span>
<span class="normal">368</span>
<span class="normal">369</span>
<span class="normal">370</span>
<span class="normal">371</span>
<span class="normal">372</span>
<span class="normal">373</span>
<span class="normal">374</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">class</span> <span class="nc">LRHistory</span><span class="p">(</span><span class="n">Callback</span><span class="p">):</span>
<span class="w">    </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;Saves the learning rates during training in the `lr_history` attribute</span>
<span class="sd">    of the `Trainer`.</span>

<span class="sd">    Callbacks are passed as input parameters to the `Trainer` class. See</span>
<span class="sd">    `pytorch_widedeep.trainer.Trainer`</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    n_epochs: int</span>
<span class="sd">        number of training epochs</span>

<span class="sd">    Examples</span>
<span class="sd">    --------</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.callbacks import LRHistory</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.models import TabMlp, Wide, WideDeep</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.training import Trainer</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; embed_input = [(u, i, j) for u, i, j in zip([&quot;a&quot;, &quot;b&quot;, &quot;c&quot;][:4], [4] * 3, [8] * 3)]</span>
<span class="sd">    &gt;&gt;&gt; column_idx = {k: v for v, k in enumerate([&quot;a&quot;, &quot;b&quot;, &quot;c&quot;])}</span>
<span class="sd">    &gt;&gt;&gt; wide = Wide(10, 1)</span>
<span class="sd">    &gt;&gt;&gt; deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, cat_embed_input=embed_input)</span>
<span class="sd">    &gt;&gt;&gt; model = WideDeep(wide, deep)</span>
<span class="sd">    &gt;&gt;&gt; trainer = Trainer(model, objective=&quot;regression&quot;, callbacks=[LRHistory(n_epochs=10)])</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_epochs</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">LRHistory</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">n_epochs</span> <span class="o">=</span> <span class="n">n_epochs</span>

    <span class="k">def</span> <span class="nf">on_epoch_begin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">logs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">epoch</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">lr_scheduler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">lr_history</span> <span class="o">=</span> <span class="p">{}</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_multiple_scheduler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="p">):</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">_save_group_lr_mulitple_scheduler</span><span class="p">(</span><span class="n">step_location</span><span class="o">=</span><span class="s2">&quot;on_epoch_begin&quot;</span><span class="p">)</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">_save_group_lr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">optimizer</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">on_batch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">logs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">lr_scheduler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_multiple_scheduler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="p">):</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">_save_group_lr_mulitple_scheduler</span><span class="p">(</span><span class="n">step_location</span><span class="o">=</span><span class="s2">&quot;on_batch_end&quot;</span><span class="p">)</span>
            <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">cyclic_lr</span><span class="p">:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">_save_group_lr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">optimizer</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">logs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">metric</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
    <span class="p">):</span>
        <span class="k">if</span> <span class="n">epoch</span> <span class="o">!=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_epochs</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">lr_scheduler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_multiple_scheduler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="p">):</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">_save_group_lr_mulitple_scheduler</span><span class="p">(</span><span class="n">step_location</span><span class="o">=</span><span class="s2">&quot;on_epoch_end&quot;</span><span class="p">)</span>
            <span class="k">elif</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">cyclic_lr</span><span class="p">:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">_save_group_lr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">optimizer</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_save_group_lr_mulitple_scheduler</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step_location</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">model_name</span><span class="p">,</span> <span class="n">optimizer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">_optimizers</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
                <span class="c1"># then, if it has schedulers, we assume it has to have the</span>
                <span class="c1"># same number of schedulers as optimizers</span>
                <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">opt</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">optimizer</span><span class="p">):</span>
                    <span class="k">if</span> <span class="p">(</span>
                        <span class="n">step_location</span> <span class="o">==</span> <span class="s2">&quot;on_epoch_begin&quot;</span>
                        <span class="ow">or</span> <span class="p">(</span>
                            <span class="n">step_location</span> <span class="o">==</span> <span class="s2">&quot;on_batch_end&quot;</span>
                            <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_has_cyclic_scheduler</span><span class="p">(</span><span class="n">model_name</span><span class="p">)</span>
                        <span class="p">)</span>
                        <span class="ow">or</span> <span class="p">(</span>
                            <span class="n">step_location</span> <span class="o">==</span> <span class="s2">&quot;on_epoch_end&quot;</span>
                            <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_has_cyclic_scheduler</span><span class="p">(</span><span class="n">model_name</span><span class="p">)</span>
                        <span class="p">)</span>
                    <span class="p">):</span>
                        <span class="bp">self</span><span class="o">.</span><span class="n">_save_group_lr</span><span class="p">(</span><span class="n">opt</span><span class="p">,</span> <span class="n">model_name</span><span class="p">,</span> <span class="s2">&quot;_&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="s2">&quot;opt&quot;</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">)]))</span>
                <span class="k">else</span><span class="p">:</span>
                    <span class="c1"># do nothing</span>
                    <span class="k">pass</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="k">if</span> <span class="p">(</span>
                    <span class="n">step_location</span> <span class="o">==</span> <span class="s2">&quot;on_epoch_begin&quot;</span>
                    <span class="ow">or</span> <span class="p">(</span>
                        <span class="n">step_location</span> <span class="o">==</span> <span class="s2">&quot;on_batch_end&quot;</span>
                        <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_has_cyclic_scheduler</span><span class="p">(</span><span class="n">model_name</span><span class="p">)</span>
                    <span class="p">)</span>
                    <span class="ow">or</span> <span class="p">(</span>
                        <span class="n">step_location</span> <span class="o">==</span> <span class="s2">&quot;on_epoch_end&quot;</span>
                        <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_has_cyclic_scheduler</span><span class="p">(</span><span class="n">model_name</span><span class="p">)</span>
                    <span class="p">)</span>
                <span class="p">):</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">_save_group_lr</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">model_name</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_save_group_lr</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">opt</span><span class="p">:</span> <span class="n">Optimizer</span><span class="p">,</span>
        <span class="n">suffix</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
        <span class="n">model_name</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    <span class="p">):</span>
        <span class="n">suffix</span> <span class="o">=</span> <span class="n">suffix</span> <span class="ow">or</span> <span class="s2">&quot;&quot;</span>
        <span class="n">model_name</span> <span class="o">=</span> <span class="n">model_name</span> <span class="ow">or</span> <span class="s2">&quot;&quot;</span>
        <span class="k">for</span> <span class="n">group_idx</span><span class="p">,</span> <span class="n">group</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">opt</span><span class="o">.</span><span class="n">param_groups</span><span class="p">):</span>
            <span class="n">group_name</span> <span class="o">=</span> <span class="p">(</span><span class="s2">&quot;_&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">join</span><span class="p">(</span>
                <span class="p">[</span><span class="n">x</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;lr&quot;</span><span class="p">,</span> <span class="n">model_name</span><span class="p">,</span> <span class="n">suffix</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">group_idx</span><span class="p">)]</span> <span class="k">if</span> <span class="n">x</span><span class="p">]</span>
            <span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">lr_history</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="n">group_name</span><span class="p">,</span> <span class="p">[])</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s2">&quot;lr&quot;</span><span class="p">])</span>

    <span class="nd">@staticmethod</span>
    <span class="k">def</span> <span class="nf">_multiple_scheduler</span><span class="p">(</span><span class="n">scheduler</span><span class="p">:</span> <span class="n">LRScheduler</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">scheduler</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;MultipleLRScheduler&quot;</span>

    <span class="k">def</span> <span class="nf">_has_cyclic_scheduler</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">model_name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">_schedulers</span><span class="p">:</span>
            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">_schedulers</span><span class="p">[</span><span class="n">model_name</span><span class="p">],</span> <span class="nb">list</span><span class="p">):</span>
                <span class="k">return</span> <span class="nb">any</span><span class="p">(</span>
                    <span class="p">[</span>
                        <span class="bp">self</span><span class="o">.</span><span class="n">_is_cyclic</span><span class="p">(</span><span class="n">s</span><span class="p">)</span>
                        <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">_schedulers</span><span class="p">[</span><span class="n">model_name</span><span class="p">]</span>
                    <span class="p">]</span>
                <span class="p">)</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_is_cyclic</span><span class="p">(</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">_schedulers</span><span class="p">[</span><span class="n">model_name</span><span class="p">]</span>
                <span class="p">)</span>

    <span class="nd">@staticmethod</span>
    <span class="k">def</span> <span class="nf">_is_cyclic</span><span class="p">(</span><span class="n">scheduler</span><span class="p">:</span> <span class="n">LRScheduler</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
        <span class="k">return</span> <span class="s2">&quot;cycl&quot;</span> <span class="ow">in</span> <span class="n">scheduler</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span>
</code></pre></div></td></tr></table></div>
              </details>



  <div class="doc doc-children">











  </div>

    </div>

</div>

<div class="doc doc-object doc-class">



<h2 id="pytorch_widedeep.callbacks.ModelCheckpoint" class="doc doc-heading">
            <span class="doc doc-object-name doc-class-name">ModelCheckpoint</span>


<a href="#pytorch_widedeep.callbacks.ModelCheckpoint" class="headerlink" title="Permanent link">&para;</a></h2>


    <div class="doc doc-contents first">
            <p class="doc doc-class-bases">
              Bases: <code><span title="pytorch_widedeep.callbacks.Callback">Callback</span></code></p>


        <p>Saves the model after every epoch.</p>
<p>This class is almost identical to the corresponding keras class.
Therefore, <strong>credit</strong> to the Keras Team.</p>
<p>Callbacks are passed as input parameters to the <code>Trainer</code> class. See
<code>pytorch_widedeep.trainer.Trainer</code>.</p>


<p><span class="doc-section-title">Parameters:</span></p>
    <table>
      <thead>
        <tr>
          <th>Name</th>
          <th>Type</th>
          <th>Description</th>
          <th>Default</th>
        </tr>
      </thead>
      <tbody>
          <tr class="doc-section-item">
            <td>
                <code>filepath</code>
            </td>
            <td>
                  <code><span title="pytorch_widedeep.wdtypes.Optional">Optional</span>[str]</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>Full path to save the output weights. It must contain only the root of
the filenames. Epoch number and <code>.pt</code> extension (for pytorch) will be
added, e.g. <code>filepath="path/to/output_weights/weights_out"</code>. The
saved files in that directory will then be named:
<em>'weights_out_1.pt', 'weights_out_2.pt', ...</em>. If set to <code>None</code>, the
class just reports the best metric and best_epoch.</p>
              </div>
            </td>
            <td>
                  <code>None</code>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>monitor</code>
            </td>
            <td>
                  <code>str</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>quantity to monitor. Typically <em>'val_loss'</em> or metric name
(e.g. <em>'val_acc'</em>)</p>
              </div>
            </td>
            <td>
                  <code>&#39;val_loss&#39;</code>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>min_delta</code>
            </td>
            <td>
                  <code>float</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>minimum change in the monitored quantity to qualify as an
improvement, i.e. an absolute change of less than <code>min_delta</code> will
count as no improvement.</p>
              </div>
            </td>
            <td>
                  <code>0.0</code>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>verbose</code>
            </td>
            <td>
                  <code>int</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>verbosity mode</p>
              </div>
            </td>
            <td>
                  <code>0</code>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>save_best_only</code>
            </td>
            <td>
                  <code>bool</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>If <code>True</code>, the latest best model according to the quantity monitored will not be
overwritten.</p>
              </div>
            </td>
            <td>
                  <code>False</code>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>mode</code>
            </td>
            <td>
                  <code>str</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>If <code>save_best_only=True</code>, the decision to overwrite the current save
file is made based on either the maximization or the minimization of
the monitored quantity. For <em>'acc'</em>, this should be <em>'max'</em>, for
<em>'loss'</em> this should be <em>'min'</em>, etc. In <em>'auto'</em> mode, the
direction is automatically inferred from the name of the monitored
quantity.</p>
              </div>
            </td>
            <td>
                  <code>&#39;auto&#39;</code>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>period</code>
            </td>
            <td>
                  <code>int</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>Interval (number of epochs) between checkpoints.</p>
              </div>
            </td>
            <td>
                  <code>1</code>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>max_save</code>
            </td>
            <td>
                  <code>int</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>Maximum number of outputs to save. If -1, all outputs will be saved.</p>
              </div>
            </td>
            <td>
                  <code>-1</code>
            </td>
          </tr>
      </tbody>
    </table>


<p><span class="doc-section-title">Attributes:</span></p>
    <table>
      <thead>
        <tr>
          <th>Name</th>
          <th>Type</th>
          <th>Description</th>
        </tr>
      </thead>
      <tbody>
          <tr class="doc-section-item">
            <td><code><span title="pytorch_widedeep.callbacks.ModelCheckpoint.best">best</span></code></td>
            <td>
                  <code>float</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>best metric</p>
              </div>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td><code><span title="pytorch_widedeep.callbacks.ModelCheckpoint.best_epoch">best_epoch</span></code></td>
            <td>
                  <code>int</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>best epoch</p>
              </div>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td><code><span title="pytorch_widedeep.callbacks.ModelCheckpoint.best_state_dict">best_state_dict</span></code></td>
            <td>
                  <code>dict</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>best model state dictionary.<br>
To restore the model to its best state use
<code>Trainer.model.load_state_dict(model_checkpoint.best_state_dict)</code> where <code>model_checkpoint</code> is an
instance of the class <code>ModelCheckpoint</code>. See the Examples folder in
the repo or the Examples section in this documentation for details.</p>
              </div>
            </td>
          </tr>
      </tbody>
    </table>


<p><span class="doc-section-title">Examples:</span></p>
    <div class="highlight"><pre><span></span><code><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.callbacks</span> <span class="kn">import</span> <span class="n">ModelCheckpoint</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.models</span> <span class="kn">import</span> <span class="n">TabMlp</span><span class="p">,</span> <span class="n">Wide</span><span class="p">,</span> <span class="n">WideDeep</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.training</span> <span class="kn">import</span> <span class="n">Trainer</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">embed_input</span> <span class="o">=</span> <span class="p">[(</span><span class="n">u</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">)</span> <span class="k">for</span> <span class="n">u</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">([</span><span class="s2">&quot;a&quot;</span><span class="p">,</span> <span class="s2">&quot;b&quot;</span><span class="p">,</span> <span class="s2">&quot;c&quot;</span><span class="p">][:</span><span class="mi">4</span><span class="p">],</span> <span class="p">[</span><span class="mi">4</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">,</span> <span class="p">[</span><span class="mi">8</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)]</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">column_idx</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">v</span><span class="p">,</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">([</span><span class="s2">&quot;a&quot;</span><span class="p">,</span> <span class="s2">&quot;b&quot;</span><span class="p">,</span> <span class="s2">&quot;c&quot;</span><span class="p">])}</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">wide</span> <span class="o">=</span> <span class="n">Wide</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">deep</span> <span class="o">=</span> <span class="n">TabMlp</span><span class="p">(</span><span class="n">mlp_hidden_dims</span><span class="o">=</span><span class="p">[</span><span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="n">column_idx</span><span class="o">=</span><span class="n">column_idx</span><span class="p">,</span> <span class="n">cat_embed_input</span><span class="o">=</span><span class="n">embed_input</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">WideDeep</span><span class="p">(</span><span class="n">wide</span><span class="p">,</span> <span class="n">deep</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">objective</span><span class="o">=</span><span class="s2">&quot;regression&quot;</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">ModelCheckpoint</span><span class="p">(</span><span class="n">filepath</span><span class="o">=</span><span class="s1">&#39;checkpoints/weights_out&#39;</span><span class="p">)])</span>
</code></pre></div>






              <details class="quote">
                <summary>Source code in <code>pytorch_widedeep/callbacks.py</code></summary>
                <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">377</span>
<span class="normal">378</span>
<span class="normal">379</span>
<span class="normal">380</span>
<span class="normal">381</span>
<span class="normal">382</span>
<span class="normal">383</span>
<span class="normal">384</span>
<span class="normal">385</span>
<span class="normal">386</span>
<span class="normal">387</span>
<span class="normal">388</span>
<span class="normal">389</span>
<span class="normal">390</span>
<span class="normal">391</span>
<span class="normal">392</span>
<span class="normal">393</span>
<span class="normal">394</span>
<span class="normal">395</span>
<span class="normal">396</span>
<span class="normal">397</span>
<span class="normal">398</span>
<span class="normal">399</span>
<span class="normal">400</span>
<span class="normal">401</span>
<span class="normal">402</span>
<span class="normal">403</span>
<span class="normal">404</span>
<span class="normal">405</span>
<span class="normal">406</span>
<span class="normal">407</span>
<span class="normal">408</span>
<span class="normal">409</span>
<span class="normal">410</span>
<span class="normal">411</span>
<span class="normal">412</span>
<span class="normal">413</span>
<span class="normal">414</span>
<span class="normal">415</span>
<span class="normal">416</span>
<span class="normal">417</span>
<span class="normal">418</span>
<span class="normal">419</span>
<span class="normal">420</span>
<span class="normal">421</span>
<span class="normal">422</span>
<span class="normal">423</span>
<span class="normal">424</span>
<span class="normal">425</span>
<span class="normal">426</span>
<span class="normal">427</span>
<span class="normal">428</span>
<span class="normal">429</span>
<span class="normal">430</span>
<span class="normal">431</span>
<span class="normal">432</span>
<span class="normal">433</span>
<span class="normal">434</span>
<span class="normal">435</span>
<span class="normal">436</span>
<span class="normal">437</span>
<span class="normal">438</span>
<span class="normal">439</span>
<span class="normal">440</span>
<span class="normal">441</span>
<span class="normal">442</span>
<span class="normal">443</span>
<span class="normal">444</span>
<span class="normal">445</span>
<span class="normal">446</span>
<span class="normal">447</span>
<span class="normal">448</span>
<span class="normal">449</span>
<span class="normal">450</span>
<span class="normal">451</span>
<span class="normal">452</span>
<span class="normal">453</span>
<span class="normal">454</span>
<span class="normal">455</span>
<span class="normal">456</span>
<span class="normal">457</span>
<span class="normal">458</span>
<span class="normal">459</span>
<span class="normal">460</span>
<span class="normal">461</span>
<span class="normal">462</span>
<span class="normal">463</span>
<span class="normal">464</span>
<span class="normal">465</span>
<span class="normal">466</span>
<span class="normal">467</span>
<span class="normal">468</span>
<span class="normal">469</span>
<span class="normal">470</span>
<span class="normal">471</span>
<span class="normal">472</span>
<span class="normal">473</span>
<span class="normal">474</span>
<span class="normal">475</span>
<span class="normal">476</span>
<span class="normal">477</span>
<span class="normal">478</span>
<span class="normal">479</span>
<span class="normal">480</span>
<span class="normal">481</span>
<span class="normal">482</span>
<span class="normal">483</span>
<span class="normal">484</span>
<span class="normal">485</span>
<span class="normal">486</span>
<span class="normal">487</span>
<span class="normal">488</span>
<span class="normal">489</span>
<span class="normal">490</span>
<span class="normal">491</span>
<span class="normal">492</span>
<span class="normal">493</span>
<span class="normal">494</span>
<span class="normal">495</span>
<span class="normal">496</span>
<span class="normal">497</span>
<span class="normal">498</span>
<span class="normal">499</span>
<span class="normal">500</span>
<span class="normal">501</span>
<span class="normal">502</span>
<span class="normal">503</span>
<span class="normal">504</span>
<span class="normal">505</span>
<span class="normal">506</span>
<span class="normal">507</span>
<span class="normal">508</span>
<span class="normal">509</span>
<span class="normal">510</span>
<span class="normal">511</span>
<span class="normal">512</span>
<span class="normal">513</span>
<span class="normal">514</span>
<span class="normal">515</span>
<span class="normal">516</span>
<span class="normal">517</span>
<span class="normal">518</span>
<span class="normal">519</span>
<span class="normal">520</span>
<span class="normal">521</span>
<span class="normal">522</span>
<span class="normal">523</span>
<span class="normal">524</span>
<span class="normal">525</span>
<span class="normal">526</span>
<span class="normal">527</span>
<span class="normal">528</span>
<span class="normal">529</span>
<span class="normal">530</span>
<span class="normal">531</span>
<span class="normal">532</span>
<span class="normal">533</span>
<span class="normal">534</span>
<span class="normal">535</span>
<span class="normal">536</span>
<span class="normal">537</span>
<span class="normal">538</span>
<span class="normal">539</span>
<span class="normal">540</span>
<span class="normal">541</span>
<span class="normal">542</span>
<span class="normal">543</span>
<span class="normal">544</span>
<span class="normal">545</span>
<span class="normal">546</span>
<span class="normal">547</span>
<span class="normal">548</span>
<span class="normal">549</span>
<span class="normal">550</span>
<span class="normal">551</span>
<span class="normal">552</span>
<span class="normal">553</span>
<span class="normal">554</span>
<span class="normal">555</span>
<span class="normal">556</span>
<span class="normal">557</span>
<span class="normal">558</span>
<span class="normal">559</span>
<span class="normal">560</span>
<span class="normal">561</span>
<span class="normal">562</span>
<span class="normal">563</span>
<span class="normal">564</span>
<span class="normal">565</span>
<span class="normal">566</span>
<span class="normal">567</span>
<span class="normal">568</span>
<span class="normal">569</span>
<span class="normal">570</span>
<span class="normal">571</span>
<span class="normal">572</span>
<span class="normal">573</span>
<span class="normal">574</span>
<span class="normal">575</span>
<span class="normal">576</span>
<span class="normal">577</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">class</span> <span class="nc">ModelCheckpoint</span><span class="p">(</span><span class="n">Callback</span><span class="p">):</span>
<span class="w">    </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;Saves the model after every epoch.</span>

<span class="sd">    This class is almost identical to the corresponding keras class.</span>
<span class="sd">    Therefore, **credit** to the Keras Team.</span>

<span class="sd">    Callbacks are passed as input parameters to the `Trainer` class. See</span>
<span class="sd">    `pytorch_widedeep.trainer.Trainer`</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    filepath: str, default=None</span>
<span class="sd">        Full path to save the output weights. It must contain only the root of</span>
<span class="sd">        the filenames. Epoch number and `.pt` extension (for pytorch) will be</span>
<span class="sd">        added. e.g. `filepath=&quot;path/to/output_weights/weights_out&quot;` And the</span>
<span class="sd">        saved files in that directory will be named:</span>
<span class="sd">        _&#39;weights_out_1.pt&#39;, &#39;weights_out_2.pt&#39;, ..._. If set to `None` the</span>
<span class="sd">        class just report best metric and best_epoch.</span>
<span class="sd">    monitor: str, default=&quot;loss&quot;</span>
<span class="sd">        quantity to monitor. Typically _&#39;val_loss&#39;_ or metric name</span>
<span class="sd">        (e.g. _&#39;val_acc&#39;_)</span>
<span class="sd">    min_delta: float, default=0.</span>
<span class="sd">        minimum change in the monitored quantity to qualify as an</span>
<span class="sd">        improvement, i.e. an absolute change of less than min_delta, will</span>
<span class="sd">        count as no improvement.</span>
<span class="sd">    verbose:int, default=0</span>
<span class="sd">        verbosity mode</span>
<span class="sd">    save_best_only: bool, default=False,</span>
<span class="sd">        the latest best model according to the quantity monitored will not be</span>
<span class="sd">        overwritten.</span>
<span class="sd">    mode: str, default=&quot;auto&quot;</span>
<span class="sd">        If `save_best_only=True`, the decision to overwrite the current save</span>
<span class="sd">        file is made based on either the maximization or the minimization of</span>
<span class="sd">        the monitored quantity. For _&#39;acc&#39;_, this should be _&#39;max&#39;_, for</span>
<span class="sd">        _&#39;loss&#39;_ this should be _&#39;min&#39;_, etc. In &#39;_auto&#39;_ mode, the</span>
<span class="sd">        direction is automatically inferred from the name of the monitored</span>
<span class="sd">        quantity.</span>
<span class="sd">    period: int, default=1</span>
<span class="sd">        Interval (number of epochs) between checkpoints.</span>
<span class="sd">    max_save: int, default=-1</span>
<span class="sd">        Maximum number of outputs to save. If -1 will save all outputs</span>

<span class="sd">    Attributes</span>
<span class="sd">    ----------</span>
<span class="sd">    best: float</span>
<span class="sd">        best metric</span>
<span class="sd">    best_epoch: int</span>
<span class="sd">        best epoch</span>
<span class="sd">    best_state_dict: dict</span>
<span class="sd">        best model state dictionary.&lt;br/&gt;</span>
<span class="sd">        To restore model to its best state use `Trainer.model.load_state_dict</span>
<span class="sd">        (model_checkpoint.best_state_dict)` where `model_checkpoint` is an</span>
<span class="sd">        instance of the class `ModelCheckpoint`. See the Examples folder in</span>
<span class="sd">        the repo or the Examples section in this documentation for details</span>

<span class="sd">    Examples</span>
<span class="sd">    --------</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.callbacks import ModelCheckpoint</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.models import TabMlp, Wide, WideDeep</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.training import Trainer</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; embed_input = [(u, i, j) for u, i, j in zip([&quot;a&quot;, &quot;b&quot;, &quot;c&quot;][:4], [4] * 3, [8] * 3)]</span>
<span class="sd">    &gt;&gt;&gt; column_idx = {k: v for v, k in enumerate([&quot;a&quot;, &quot;b&quot;, &quot;c&quot;])}</span>
<span class="sd">    &gt;&gt;&gt; wide = Wide(10, 1)</span>
<span class="sd">    &gt;&gt;&gt; deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, cat_embed_input=embed_input)</span>
<span class="sd">    &gt;&gt;&gt; model = WideDeep(wide, deep)</span>
<span class="sd">    &gt;&gt;&gt; trainer = Trainer(model, objective=&quot;regression&quot;, callbacks=[ModelCheckpoint(filepath=&#39;checkpoints/weights_out&#39;)])</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">filepath</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
        <span class="n">monitor</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;val_loss&quot;</span><span class="p">,</span>
        <span class="n">min_delta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
        <span class="n">verbose</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
        <span class="n">save_best_only</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
        <span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;auto&quot;</span><span class="p">,</span>
        <span class="n">period</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
        <span class="n">max_save</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
    <span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">ModelCheckpoint</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">filepath</span> <span class="o">=</span> <span class="n">filepath</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">monitor</span> <span class="o">=</span> <span class="n">monitor</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span> <span class="o">=</span> <span class="n">min_delta</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">verbose</span> <span class="o">=</span> <span class="n">verbose</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">save_best_only</span> <span class="o">=</span> <span class="n">save_best_only</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">period</span> <span class="o">=</span> <span class="n">period</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">max_save</span> <span class="o">=</span> <span class="n">max_save</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">epochs_since_last_save</span> <span class="o">=</span> <span class="mi">0</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">filepath</span><span class="p">:</span>
            <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filepath</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;/&quot;</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
                    <span class="s2">&quot;&#39;filepath&#39; must be the full path to save the output weights,&quot;</span>
                    <span class="s2">&quot; including the root of the filenames. e.g. &#39;checkpoints/weights_out&#39;&quot;</span>
                <span class="p">)</span>

            <span class="n">root_dir</span> <span class="o">=</span> <span class="p">(</span><span class="s2">&quot;/&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filepath</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;/&quot;</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
            <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">root_dir</span><span class="p">):</span>
                <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">root_dir</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_save</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">old_files</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;auto&quot;</span><span class="p">,</span> <span class="s2">&quot;min&quot;</span><span class="p">,</span> <span class="s2">&quot;max&quot;</span><span class="p">]:</span>
            <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span>
                <span class="s2">&quot;ModelCheckpoint mode </span><span class="si">%s</span><span class="s2"> is unknown, &quot;</span>
                <span class="s2">&quot;fallback to auto mode.&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">),</span>
                <span class="ne">RuntimeWarning</span><span class="p">,</span>
            <span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="s2">&quot;auto&quot;</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;min&quot;</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">less</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">best</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">Inf</span>
        <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;max&quot;</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">greater</span>  <span class="c1"># type: ignore[assignment]</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">best</span> <span class="o">=</span> <span class="o">-</span><span class="n">np</span><span class="o">.</span><span class="n">Inf</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">if</span> <span class="n">_is_metric</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor</span><span class="p">):</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">greater</span>  <span class="c1"># type: ignore[assignment]</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">best</span> <span class="o">=</span> <span class="o">-</span><span class="n">np</span><span class="o">.</span><span class="n">Inf</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">less</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">best</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">Inf</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">greater</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span> <span class="o">*=</span> <span class="mi">1</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span> <span class="o">*=</span> <span class="o">-</span><span class="mi">1</span>

    <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span>  <span class="c1"># noqa: C901</span>
        <span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">logs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">metric</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
    <span class="p">):</span>
        <span class="n">logs</span> <span class="o">=</span> <span class="n">logs</span> <span class="ow">or</span> <span class="p">{}</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">epochs_since_last_save</span> <span class="o">+=</span> <span class="mi">1</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">epochs_since_last_save</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">period</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">epochs_since_last_save</span> <span class="o">=</span> <span class="mi">0</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">filepath</span><span class="p">:</span>
                <span class="n">filepath</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="si">{}</span><span class="s2">_</span><span class="si">{}</span><span class="s2">.p&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filepath</span><span class="p">,</span> <span class="n">epoch</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">save_best_only</span><span class="p">:</span>
                <span class="n">current</span> <span class="o">=</span> <span class="n">logs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor</span><span class="p">)</span>
                <span class="k">if</span> <span class="n">current</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
                    <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span>
                        <span class="s2">&quot;Can save best model only with </span><span class="si">%s</span><span class="s2"> available, &quot;</span>
                        <span class="s2">&quot;skipping.&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor</span><span class="p">),</span>
                        <span class="ne">RuntimeWarning</span><span class="p">,</span>
                    <span class="p">)</span>
                <span class="k">else</span><span class="p">:</span>
                    <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span><span class="p">(</span><span class="n">current</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">best</span><span class="p">):</span>
                        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbose</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">filepath</span><span class="p">:</span>
                                <span class="nb">print</span><span class="p">(</span>
                                    <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">1</span><span class="si">}</span><span class="s2">: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor</span><span class="si">}</span><span class="s2"> improved from </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">best</span><span class="si">:</span><span class="s2">.5f</span><span class="si">}</span><span class="s2"> to </span><span class="si">{</span><span class="n">current</span><span class="si">:</span><span class="s2">.5f</span><span class="si">}</span><span class="s2"> &quot;</span>
                                    <span class="sa">f</span><span class="s2">&quot;Saving model to </span><span class="si">{</span><span class="n">filepath</span><span class="si">}</span><span class="s2">&quot;</span>
                                <span class="p">)</span>
                            <span class="k">else</span><span class="p">:</span>
                                <span class="nb">print</span><span class="p">(</span>
                                    <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">1</span><span class="si">}</span><span class="s2">: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor</span><span class="si">}</span><span class="s2"> improved from </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">best</span><span class="si">:</span><span class="s2">.5f</span><span class="si">}</span><span class="s2"> to </span><span class="si">{</span><span class="n">current</span><span class="si">:</span><span class="s2">.5f</span><span class="si">}</span><span class="s2"> &quot;</span>
                                <span class="p">)</span>
                        <span class="bp">self</span><span class="o">.</span><span class="n">best</span> <span class="o">=</span> <span class="n">current</span>
                        <span class="bp">self</span><span class="o">.</span><span class="n">best_epoch</span> <span class="o">=</span> <span class="n">epoch</span>
                        <span class="bp">self</span><span class="o">.</span><span class="n">best_state_dict</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">())</span>
                        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">filepath</span><span class="p">:</span>
                            <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">best_state_dict</span><span class="p">,</span> <span class="n">filepath</span><span class="p">)</span>
                            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_save</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                                <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">old_files</span><span class="p">)</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_save</span><span class="p">:</span>
                                    <span class="k">try</span><span class="p">:</span>
                                        <span class="n">os</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">old_files</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
                                    <span class="k">except</span> <span class="ne">FileNotFoundError</span><span class="p">:</span>
                                        <span class="k">pass</span>
                                    <span class="bp">self</span><span class="o">.</span><span class="n">old_files</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">old_files</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span>
                                <span class="bp">self</span><span class="o">.</span><span class="n">old_files</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">filepath</span><span class="p">)</span>
                    <span class="k">else</span><span class="p">:</span>
                        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbose</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                            <span class="nb">print</span><span class="p">(</span>
                                <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">1</span><span class="si">}</span><span class="s2">: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor</span><span class="si">}</span><span class="s2"> did not improve from </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">best</span><span class="si">:</span><span class="s2">.5f</span><span class="si">}</span><span class="s2"> &quot;</span>
                                <span class="sa">f</span><span class="s2">&quot; considering a &#39;min_delta&#39; improvement of </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span><span class="si">:</span><span class="s2">.5f</span><span class="si">}</span><span class="s2">&quot;</span>
                            <span class="p">)</span>
            <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">save_best_only</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">filepath</span><span class="p">:</span>
                <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbose</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                    <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Epoch </span><span class="si">%05d</span><span class="s2">: saving model to </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="n">epoch</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">filepath</span><span class="p">))</span>
                <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">filepath</span><span class="p">)</span>
                <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_save</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                    <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">old_files</span><span class="p">)</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_save</span><span class="p">:</span>
                        <span class="k">try</span><span class="p">:</span>
                            <span class="n">os</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">old_files</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
                        <span class="k">except</span> <span class="ne">FileNotFoundError</span><span class="p">:</span>
                            <span class="k">pass</span>
                        <span class="bp">self</span><span class="o">.</span><span class="n">old_files</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">old_files</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">old_files</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">filepath</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">__getstate__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">d</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span>
        <span class="n">self_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">d</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">d</span> <span class="k">if</span> <span class="n">k</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;trainer&quot;</span><span class="p">,</span> <span class="s2">&quot;model&quot;</span><span class="p">]}</span>
        <span class="k">return</span> <span class="n">self_dict</span>

    <span class="k">def</span> <span class="nf">__setstate__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span>
        <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span> <span class="o">=</span> <span class="n">state</span>
</code></pre></div></td></tr></table></div>
              </details>



  <div class="doc doc-children">











  </div>

    </div>

</div>

<div class="doc doc-object doc-class">



<h2 id="pytorch_widedeep.callbacks.EarlyStopping" class="doc doc-heading">
            <span class="doc doc-object-name doc-class-name">EarlyStopping</span>


<a href="#pytorch_widedeep.callbacks.EarlyStopping" class="headerlink" title="Permanent link">&para;</a></h2>


    <div class="doc doc-contents first">
            <p class="doc doc-class-bases">
              Bases: <code><span title="pytorch_widedeep.callbacks.Callback">Callback</span></code></p>


        <p>Stop training when a monitored quantity has stopped improving.</p>
<p>This class is almost identical to the corresponding keras class.
Therefore, <strong>credit</strong> to the Keras Team.</p>
<p>Callbacks are passed as input parameters to the <code>Trainer</code> class. See
<code>pytorch_widedeep.training.Trainer</code>.</p>


<p><span class="doc-section-title">Parameters:</span></p>
    <table>
      <thead>
        <tr>
          <th>Name</th>
          <th>Type</th>
          <th>Description</th>
          <th>Default</th>
        </tr>
      </thead>
      <tbody>
          <tr class="doc-section-item">
            <td>
                <code>monitor</code>
            </td>
            <td>
                  <code>str</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>Quantity to monitor. Typically <em>'val_loss'</em> or metric name
(e.g. <em>'val_acc'</em>)</p>
              </div>
            </td>
            <td>
                  <code>&#39;val_loss&#39;</code>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>min_delta</code>
            </td>
            <td>
                  <code>float</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>Minimum change in the monitored quantity to qualify as an
improvement, i.e. an absolute change of less than <code>min_delta</code> will
count as no improvement.</p>
              </div>
            </td>
            <td>
                  <code>0.0</code>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>patience</code>
            </td>
            <td>
                  <code>int</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>Number of epochs that produced the monitored quantity with no
improvement after which training will be stopped.</p>
              </div>
            </td>
            <td>
                  <code>10</code>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>verbose</code>
            </td>
            <td>
                  <code>int</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>Verbosity mode.</p>
              </div>
            </td>
            <td>
                  <code>0</code>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>mode</code>
            </td>
            <td>
                  <code>str</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>One of <em>{'auto', 'min', 'max'}</em>. In <em>'min'</em> mode, training will
stop when the quantity monitored has stopped decreasing; in <em>'max'</em>
mode it will stop when the quantity monitored has stopped increasing;
in <em>'auto'</em> mode, the direction is automatically inferred from the
name of the monitored quantity.</p>
              </div>
            </td>
            <td>
                  <code>&#39;auto&#39;</code>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>baseline</code>
            </td>
            <td>
                  <code><span title="pytorch_widedeep.wdtypes.Optional">Optional</span>[float]</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>Baseline value for the monitored quantity to reach. Training will
stop if the model does not show improvement over the baseline.</p>
              </div>
            </td>
            <td>
                  <code>None</code>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>restore_best_weights</code>
            </td>
            <td>
                  <code>bool</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>Whether to restore model weights from the epoch with the best
value of the monitored quantity. If <code>False</code>, the model weights
obtained at the last step of training are used.</p>
              </div>
            </td>
            <td>
                  <code>False</code>
            </td>
          </tr>
      </tbody>
    </table>


<p><span class="doc-section-title">Attributes:</span></p>
    <table>
      <thead>
        <tr>
          <th>Name</th>
          <th>Type</th>
          <th>Description</th>
        </tr>
      </thead>
      <tbody>
          <tr class="doc-section-item">
            <td><code><span title="pytorch_widedeep.callbacks.EarlyStopping.best">best</span></code></td>
            <td>
                  <code>float</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>Best value of the monitored quantity.</p>
              </div>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td><code><span title="pytorch_widedeep.callbacks.EarlyStopping.stopped_epoch">stopped_epoch</span></code></td>
            <td>
                  <code>int</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>Epoch at which training was stopped.</p>
              </div>
            </td>
          </tr>
      </tbody>
    </table>


<p><span class="doc-section-title">Examples:</span></p>
    <div class="highlight"><pre><span></span><code><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.callbacks</span> <span class="kn">import</span> <span class="n">EarlyStopping</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.models</span> <span class="kn">import</span> <span class="n">TabMlp</span><span class="p">,</span> <span class="n">Wide</span><span class="p">,</span> <span class="n">WideDeep</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.training</span> <span class="kn">import</span> <span class="n">Trainer</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">embed_input</span> <span class="o">=</span> <span class="p">[(</span><span class="n">u</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">)</span> <span class="k">for</span> <span class="n">u</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">([</span><span class="s2">&quot;a&quot;</span><span class="p">,</span> <span class="s2">&quot;b&quot;</span><span class="p">,</span> <span class="s2">&quot;c&quot;</span><span class="p">][:</span><span class="mi">4</span><span class="p">],</span> <span class="p">[</span><span class="mi">4</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">,</span> <span class="p">[</span><span class="mi">8</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)]</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">column_idx</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">v</span><span class="p">,</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">([</span><span class="s2">&quot;a&quot;</span><span class="p">,</span> <span class="s2">&quot;b&quot;</span><span class="p">,</span> <span class="s2">&quot;c&quot;</span><span class="p">])}</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">wide</span> <span class="o">=</span> <span class="n">Wide</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">deep</span> <span class="o">=</span> <span class="n">TabMlp</span><span class="p">(</span><span class="n">mlp_hidden_dims</span><span class="o">=</span><span class="p">[</span><span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="n">column_idx</span><span class="o">=</span><span class="n">column_idx</span><span class="p">,</span> <span class="n">cat_embed_input</span><span class="o">=</span><span class="n">embed_input</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">WideDeep</span><span class="p">(</span><span class="n">wide</span><span class="p">,</span> <span class="n">deep</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">objective</span><span class="o">=</span><span class="s2">&quot;regression&quot;</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">EarlyStopping</span><span class="p">(</span><span class="n">patience</span><span class="o">=</span><span class="mi">10</span><span class="p">)])</span>
</code></pre></div>






              <details class="quote">
                <summary>Source code in <code>pytorch_widedeep/callbacks.py</code></summary>
                <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">580</span>
<span class="normal">581</span>
<span class="normal">582</span>
<span class="normal">583</span>
<span class="normal">584</span>
<span class="normal">585</span>
<span class="normal">586</span>
<span class="normal">587</span>
<span class="normal">588</span>
<span class="normal">589</span>
<span class="normal">590</span>
<span class="normal">591</span>
<span class="normal">592</span>
<span class="normal">593</span>
<span class="normal">594</span>
<span class="normal">595</span>
<span class="normal">596</span>
<span class="normal">597</span>
<span class="normal">598</span>
<span class="normal">599</span>
<span class="normal">600</span>
<span class="normal">601</span>
<span class="normal">602</span>
<span class="normal">603</span>
<span class="normal">604</span>
<span class="normal">605</span>
<span class="normal">606</span>
<span class="normal">607</span>
<span class="normal">608</span>
<span class="normal">609</span>
<span class="normal">610</span>
<span class="normal">611</span>
<span class="normal">612</span>
<span class="normal">613</span>
<span class="normal">614</span>
<span class="normal">615</span>
<span class="normal">616</span>
<span class="normal">617</span>
<span class="normal">618</span>
<span class="normal">619</span>
<span class="normal">620</span>
<span class="normal">621</span>
<span class="normal">622</span>
<span class="normal">623</span>
<span class="normal">624</span>
<span class="normal">625</span>
<span class="normal">626</span>
<span class="normal">627</span>
<span class="normal">628</span>
<span class="normal">629</span>
<span class="normal">630</span>
<span class="normal">631</span>
<span class="normal">632</span>
<span class="normal">633</span>
<span class="normal">634</span>
<span class="normal">635</span>
<span class="normal">636</span>
<span class="normal">637</span>
<span class="normal">638</span>
<span class="normal">639</span>
<span class="normal">640</span>
<span class="normal">641</span>
<span class="normal">642</span>
<span class="normal">643</span>
<span class="normal">644</span>
<span class="normal">645</span>
<span class="normal">646</span>
<span class="normal">647</span>
<span class="normal">648</span>
<span class="normal">649</span>
<span class="normal">650</span>
<span class="normal">651</span>
<span class="normal">652</span>
<span class="normal">653</span>
<span class="normal">654</span>
<span class="normal">655</span>
<span class="normal">656</span>
<span class="normal">657</span>
<span class="normal">658</span>
<span class="normal">659</span>
<span class="normal">660</span>
<span class="normal">661</span>
<span class="normal">662</span>
<span class="normal">663</span>
<span class="normal">664</span>
<span class="normal">665</span>
<span class="normal">666</span>
<span class="normal">667</span>
<span class="normal">668</span>
<span class="normal">669</span>
<span class="normal">670</span>
<span class="normal">671</span>
<span class="normal">672</span>
<span class="normal">673</span>
<span class="normal">674</span>
<span class="normal">675</span>
<span class="normal">676</span>
<span class="normal">677</span>
<span class="normal">678</span>
<span class="normal">679</span>
<span class="normal">680</span>
<span class="normal">681</span>
<span class="normal">682</span>
<span class="normal">683</span>
<span class="normal">684</span>
<span class="normal">685</span>
<span class="normal">686</span>
<span class="normal">687</span>
<span class="normal">688</span>
<span class="normal">689</span>
<span class="normal">690</span>
<span class="normal">691</span>
<span class="normal">692</span>
<span class="normal">693</span>
<span class="normal">694</span>
<span class="normal">695</span>
<span class="normal">696</span>
<span class="normal">697</span>
<span class="normal">698</span>
<span class="normal">699</span>
<span class="normal">700</span>
<span class="normal">701</span>
<span class="normal">702</span>
<span class="normal">703</span>
<span class="normal">704</span>
<span class="normal">705</span>
<span class="normal">706</span>
<span class="normal">707</span>
<span class="normal">708</span>
<span class="normal">709</span>
<span class="normal">710</span>
<span class="normal">711</span>
<span class="normal">712</span>
<span class="normal">713</span>
<span class="normal">714</span>
<span class="normal">715</span>
<span class="normal">716</span>
<span class="normal">717</span>
<span class="normal">718</span>
<span class="normal">719</span>
<span class="normal">720</span>
<span class="normal">721</span>
<span class="normal">722</span>
<span class="normal">723</span>
<span class="normal">724</span>
<span class="normal">725</span>
<span class="normal">726</span>
<span class="normal">727</span>
<span class="normal">728</span>
<span class="normal">729</span>
<span class="normal">730</span>
<span class="normal">731</span>
<span class="normal">732</span>
<span class="normal">733</span>
<span class="normal">734</span>
<span class="normal">735</span>
<span class="normal">736</span>
<span class="normal">737</span>
<span class="normal">738</span>
<span class="normal">739</span>
<span class="normal">740</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">class</span> <span class="nc">EarlyStopping</span><span class="p">(</span><span class="n">Callback</span><span class="p">):</span>
<span class="w">    </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;Stop training when a monitored quantity has stopped improving.</span>

<span class="sd">    This class is almost identical to the corresponding keras class.</span>
<span class="sd">    Therefore, **credit** to the Keras Team.</span>

<span class="sd">    Callbacks are passed as input parameters to the `Trainer` class. See</span>
<span class="sd">    `pytorch_widedeep.trainer.Trainer`</span>

<span class="sd">    Parameters</span>
<span class="sd">    -----------</span>
<span class="sd">    monitor: str, default=&#39;val_loss&#39;.</span>
<span class="sd">        Quantity to monitor. Typically _&#39;val_loss&#39;_ or metric name</span>
<span class="sd">        (e.g. _&#39;val_acc&#39;_)</span>
<span class="sd">    min_delta: float, default=0.</span>
<span class="sd">        minimum change in the monitored quantity to qualify as an</span>
<span class="sd">        improvement, i.e. an absolute change of less than min_delta, will</span>
<span class="sd">        count as no improvement.</span>
<span class="sd">    patience: int, default=10.</span>
<span class="sd">        Number of epochs that produced the monitored quantity with no</span>
<span class="sd">        improvement after which training will be stopped.</span>
<span class="sd">    verbose: int.</span>
<span class="sd">        verbosity mode.</span>
<span class="sd">    mode: str, default=&#39;auto&#39;</span>
<span class="sd">        one of _{&#39;auto&#39;, &#39;min&#39;, &#39;max&#39;}_. In _&#39;min&#39;_ mode, training will</span>
<span class="sd">        stop when the quantity monitored has stopped decreasing; in _&#39;max&#39;_</span>
<span class="sd">        mode it will stop when the quantity monitored has stopped increasing;</span>
<span class="sd">        in _&#39;auto&#39;_ mode, the direction is automatically inferred from the</span>
<span class="sd">        name of the monitored quantity.</span>
<span class="sd">    baseline: float, Optional. default=None.</span>
<span class="sd">        Baseline value for the monitored quantity to reach. Training will</span>
<span class="sd">        stop if the model does not show improvement over the baseline.</span>
<span class="sd">    restore_best_weights: bool, default=False</span>
<span class="sd">        Whether to restore model weights from the epoch with the best</span>
<span class="sd">        value of the monitored quantity. If `False`, the model weights</span>
<span class="sd">        obtained at the last step of training are used.</span>

<span class="sd">    Attributes</span>
<span class="sd">    ----------</span>
<span class="sd">    best: float</span>
<span class="sd">        best metric</span>
<span class="sd">    stopped_epoch: int</span>
<span class="sd">        epoch when the training stopped</span>

<span class="sd">    Examples</span>
<span class="sd">    --------</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.callbacks import EarlyStopping</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.models import TabMlp, Wide, WideDeep</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.training import Trainer</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; embed_input = [(u, i, j) for u, i, j in zip([&quot;a&quot;, &quot;b&quot;, &quot;c&quot;][:4], [4] * 3, [8] * 3)]</span>
<span class="sd">    &gt;&gt;&gt; column_idx = {k: v for v, k in enumerate([&quot;a&quot;, &quot;b&quot;, &quot;c&quot;])}</span>
<span class="sd">    &gt;&gt;&gt; wide = Wide(10, 1)</span>
<span class="sd">    &gt;&gt;&gt; deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, cat_embed_input=embed_input)</span>
<span class="sd">    &gt;&gt;&gt; model = WideDeep(wide, deep)</span>
<span class="sd">    &gt;&gt;&gt; trainer = Trainer(model, objective=&quot;regression&quot;, callbacks=[EarlyStopping(patience=10)])</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">monitor</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;val_loss&quot;</span><span class="p">,</span>
        <span class="n">min_delta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
        <span class="n">patience</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
        <span class="n">verbose</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
        <span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;auto&quot;</span><span class="p">,</span>
        <span class="n">baseline</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
        <span class="n">restore_best_weights</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
    <span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">EarlyStopping</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">monitor</span> <span class="o">=</span> <span class="n">monitor</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span> <span class="o">=</span> <span class="n">min_delta</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">patience</span> <span class="o">=</span> <span class="n">patience</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">verbose</span> <span class="o">=</span> <span class="n">verbose</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">baseline</span> <span class="o">=</span> <span class="n">baseline</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">restore_best_weights</span> <span class="o">=</span> <span class="n">restore_best_weights</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">wait</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">stopped_epoch</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">state_dict</span> <span class="o">=</span> <span class="kc">None</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;auto&quot;</span><span class="p">,</span> <span class="s2">&quot;min&quot;</span><span class="p">,</span> <span class="s2">&quot;max&quot;</span><span class="p">]:</span>
            <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span>
                <span class="s2">&quot;EarlyStopping mode </span><span class="si">%s</span><span class="s2"> is unknown, &quot;</span>
                <span class="s2">&quot;fallback to auto mode.&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span>
                <span class="ne">RuntimeWarning</span><span class="p">,</span>
            <span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="s2">&quot;auto&quot;</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;min&quot;</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">less</span>
        <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;max&quot;</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">greater</span>  <span class="c1"># type: ignore[assignment]</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">if</span> <span class="n">_is_metric</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor</span><span class="p">):</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">greater</span>  <span class="c1"># type: ignore[assignment]</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">less</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">greater</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span> <span class="o">*=</span> <span class="mi">1</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span> <span class="o">*=</span> <span class="o">-</span><span class="mi">1</span>

    <span class="k">def</span> <span class="nf">on_train_begin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
        <span class="c1"># Allow instances to be re-used</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">wait</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">stopped_epoch</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">baseline</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">best</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">baseline</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">best</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">Inf</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">less</span> <span class="k">else</span> <span class="o">-</span><span class="n">np</span><span class="o">.</span><span class="n">Inf</span>

    <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">logs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">metric</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
    <span class="p">):</span>
        <span class="n">current</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_monitor_value</span><span class="p">(</span><span class="n">logs</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">current</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">return</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span><span class="p">(</span><span class="n">current</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">best</span><span class="p">):</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">best</span> <span class="o">=</span> <span class="n">current</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">wait</span> <span class="o">=</span> <span class="mi">0</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">best_epoch</span> <span class="o">=</span> <span class="n">epoch</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">restore_best_weights</span><span class="p">:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">state_dict</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">())</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">wait</span> <span class="o">+=</span> <span class="mi">1</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">wait</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patience</span><span class="p">:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">stopped_epoch</span> <span class="o">=</span> <span class="n">epoch</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">early_stop</span> <span class="o">=</span> <span class="kc">True</span>

    <span class="k">def</span> <span class="nf">on_train_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stopped_epoch</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbose</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="nb">print</span><span class="p">(</span>
                <span class="sa">f</span><span class="s2">&quot;Best Epoch: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">best_epoch</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">1</span><span class="si">}</span><span class="s2">. Best </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor</span><span class="si">}</span><span class="s2">: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">best</span><span class="si">:</span><span class="s2">.5f</span><span class="si">}</span><span class="s2">&quot;</span>
            <span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">restore_best_weights</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_dict</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbose</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Restoring model weights from the end of the best epoch&quot;</span><span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">state_dict</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">get_monitor_value</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logs</span><span class="p">):</span>
        <span class="n">monitor_value</span> <span class="o">=</span> <span class="n">logs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">monitor_value</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span>
                <span class="s2">&quot;Early stopping conditioned on metric `</span><span class="si">%s</span><span class="s2">` &quot;</span>
                <span class="s2">&quot;which is not available. Available metrics are: </span><span class="si">%s</span><span class="s2">&quot;</span>
                <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor</span><span class="p">,</span> <span class="s2">&quot;,&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">()))),</span>
                <span class="ne">RuntimeWarning</span><span class="p">,</span>
            <span class="p">)</span>
        <span class="k">return</span> <span class="n">monitor_value</span>

    <span class="k">def</span> <span class="nf">__getstate__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">d</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span>
        <span class="n">self_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">d</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">d</span> <span class="k">if</span> <span class="n">k</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;trainer&quot;</span><span class="p">,</span> <span class="s2">&quot;model&quot;</span><span class="p">]}</span>
        <span class="k">return</span> <span class="n">self_dict</span>

    <span class="k">def</span> <span class="nf">__setstate__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span>
        <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span> <span class="o">=</span> <span class="n">state</span>
</code></pre></div></td></tr></table></div>
              </details>



  <div class="doc doc-children">











  </div>

    </div>

</div>








  <aside class="md-source-file">
    
    
    
      
  
  <span class="md-source-file__fact">
    <span class="md-icon" title="Contributors">
      
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 5.5A3.5 3.5 0 0 1 15.5 9a3.5 3.5 0 0 1-3.5 3.5A3.5 3.5 0 0 1 8.5 9 3.5 3.5 0 0 1 12 5.5M5 8c.56 0 1.08.15 1.53.42-.15 1.43.27 2.85 1.13 3.96C7.16 13.34 6.16 14 5 14a3 3 0 0 1-3-3 3 3 0 0 1 3-3m14 0a3 3 0 0 1 3 3 3 3 0 0 1-3 3c-1.16 0-2.16-.66-2.66-1.62a5.54 5.54 0 0 0 1.13-3.96c.45-.27.97-.42 1.53-.42M5.5 18.25c0-2.07 2.91-3.75 6.5-3.75s6.5 1.68 6.5 3.75V20h-13zM0 20v-1.5c0-1.39 1.89-2.56 4.45-2.9-.59.68-.95 1.62-.95 2.65V20zm24 0h-3.5v-1.75c0-1.03-.36-1.97-.95-2.65 2.56.34 4.45 1.51 4.45 2.9z"/></svg>
      
    </span>
    <nav>
      
        <a href="mailto:javierrodriguezzaurin@javiers-macbook-pro.local">Javier Rodriguez Zaurin</a>, 
        <a href="mailto:mulinka.pavol@gmail.com">Pavol Mulinka</a>
    </nav>
  </span>

    
    
  </aside>





                
              </article>
            </div>
          
          
<script>
/* Looks up the element whose id equals the URL fragment (without the leading
   "#"). If it exists and has a non-empty name, its `checked` property is set
   to whether that name starts with "__tabbed_".
   NOTE(review): presumably this pre-selects the content tab a deep link
   points at — confirm against the theme's tabbed-content markup. */
var target=document.getElementById(location.hash.slice(1));target&&target.name&&(target.checked=target.name.startsWith("__tabbed_"))</script>
        </div>
        
      </main>
      
        <footer class="md-footer">
  
  <div class="md-footer-meta md-typeset">
    <div class="md-footer-meta__inner md-grid">
      <div class="md-copyright">
  
    <div class="md-copyright__highlight">
      Javier Zaurin and Pavol Mulinka
    </div>
  
  
    Made with
    <a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener">
      Material for MkDocs
    </a>
  
</div>
      
        <div class="md-social">
  
    
    
    
    
      
      
    
    <a href="https://jrzaurin.medium.com/" target="_blank" rel="noopener" title="jrzaurin.medium.com" class="md-social__link">
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 640 512"><!--! Font Awesome Free 6.6.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M180.5 74.262C80.813 74.262 0 155.633 0 256s80.819 181.738 180.5 181.738S361 356.373 361 256 280.191 74.262 180.5 74.262m288.25 10.646c-49.845 0-90.245 76.619-90.245 171.095s40.406 171.1 90.251 171.1 90.251-76.619 90.251-171.1H559c0-94.503-40.4-171.095-90.248-171.095Zm139.506 17.821c-17.526 0-31.735 68.628-31.735 153.274s14.2 153.274 31.735 153.274S640 340.631 640 256c0-84.649-14.215-153.271-31.742-153.271Z"/></svg>
    </a>
  
</div>
      
    </div>
  </div>
</footer>
      
    </div>
    <div class="md-dialog" data-md-component="dialog">
      <div class="md-dialog__inner md-typeset"></div>
    </div>
    
    
    <script id="__config" type="application/json">{"base": "..", "features": ["navigation.tabs", "navigation.tabs.sticky", "navigation.indexes", "navigation.expand", "toc.integrate"], "search": "../assets/javascripts/workers/search.6ce7567c.min.js", "translations": {"clipboard.copied": "Copied to clipboard", "clipboard.copy": "Copy to clipboard", "search.result.more.one": "1 more on this page", "search.result.more.other": "# more on this page", "search.result.none": "No matching documents", "search.result.one": "1 matching document", "search.result.other": "# matching documents", "search.result.placeholder": "Type to start searching", "search.result.term.missing": "Missing", "select.version": "Select version"}}</script>
    
    
      <script src="../assets/javascripts/bundle.83f73b43.min.js"></script>
      
        <script src="../stylesheets/extra.js"></script>
      
        <script src="../javascripts/mathjax.js"></script>
      
        <!-- ES6 polyfill is loaded from the Cloudflare cdnjs mirror: the
             polyfill.io domain changed ownership and served malicious code
             in 2024, so it must not be referenced. Same path and query as
             the original service. -->
        <script src="https://cdnjs.cloudflare.com/polyfill/v3/polyfill.min.js?features=es6"></script>
      
        <script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
      
    
  </body>
</html>